{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet import gluon, autograd, nd\n",
    "import numpy as np\n",
    "import pickle as p\n",
    "import matplotlib.pyplot as plt\n",
    "from time import time\n",
    "%matplotlib inline\n",
    "ctx = mx.gpu()\n",
    "data_dir = '/home/sinyer/python/data/cifar10'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_cifar(route = data_dir+'/cifar-10-batches-py'):\n",
    "    def load_batch(filename):\n",
    "        with open(filename, 'rb')as f:\n",
    "            data_dict = p.load(f, encoding='latin1')\n",
    "            X = data_dict['data']\n",
    "            Y = data_dict['labels']\n",
    "            X = X.reshape(10000, 3, 32,32).astype(\"float\")\n",
    "            Y = np.array(Y)\n",
    "            return X, Y\n",
    "    def load_labels(filename):\n",
    "        with open(filename, 'rb') as f:\n",
    "            label_names = p.load(f, encoding='latin1')\n",
    "            names = label_names['label_names']\n",
    "            return names\n",
    "    label_names = load_labels(route + \"/batches.meta\")\n",
    "    x1, y1 = load_batch(route + \"/data_batch_1\")\n",
    "    x2, y2 = load_batch(route + \"/data_batch_2\")\n",
    "    x3, y3 = load_batch(route + \"/data_batch_3\")\n",
    "    x4, y4 = load_batch(route + \"/data_batch_4\")\n",
    "    x5, y5 = load_batch(route + \"/data_batch_5\")\n",
    "    test_pic, test_label = load_batch(route + \"/test_batch\")\n",
    "    train_pic = np.concatenate((x1, x2, x3, x4, x5))\n",
    "    train_label = np.concatenate((y1, y2, y3, y4, y5))\n",
    "    return train_pic, train_label, test_pic, test_label\n",
    "\n",
    "def accuracy(output, label):\n",
    "    return nd.mean(output.argmax(axis=1)==label).asscalar()\n",
    "\n",
    "def evaluate_accuracy(data_iterator, net, ctx):\n",
    "    acc = 0.\n",
    "    for data, label in data_iterator:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        output = net(data)\n",
    "        acc += accuracy(output, label)\n",
    "    return acc / len(data_iterator)\n",
    "\n",
    "def relu(x):\n",
    "    return nd.maximum(x, 0)\n",
    "\n",
    "def net(x):\n",
    "    x = x.reshape((-1, inp))\n",
    "    h1 = relu(nd.dot(x, w1) + b1)\n",
    "    output = nd.dot(h1, w2) + b2\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_pic, train_label, test_pic, test_label = load_cifar()\n",
    "\n",
    "batch_size = 128\n",
    "train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(\n",
    "    train_pic.astype('float32')/255, train_label.astype('float32')), batch_size, shuffle=True)\n",
    "test_data = gluon.data.DataLoader(gluon.data.ArrayDataset(\n",
    "    test_pic.astype('float32')/255, test_label.astype('float32')), batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "n, inp, h, out = 50000, 3072, 128, 10\n",
    "weight_scale = .01\n",
    "\n",
    "w1 = nd.random_normal(shape=(inp, h), scale=weight_scale, ctx=ctx)\n",
    "b1 = nd.zeros(h, ctx=ctx)\n",
    "w2 = nd.random_normal(shape=(h, out), scale=weight_scale, ctx=ctx)\n",
    "b2 = nd.zeros(out, ctx=ctx)\n",
    "\n",
    "params = [w1, b1, w2, b2]\n",
    "for param in params:\n",
    "    param.attach_grad()\n",
    "\n",
    "loss = gluon.loss.SoftmaxCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E 0; L 2.235740; Tr_acc 0.162408; Te_acc 0.231309; T 1.237913\n",
      "E 20; L 1.553429; Tr_acc 0.456166; Te_acc 0.448873; T 0.941818\n",
      "E 40; L 1.408546; Tr_acc 0.508336; Te_acc 0.487144; T 0.993180\n",
      "E 60; L 1.314173; Tr_acc 0.540509; Te_acc 0.496341; T 1.290730\n",
      "E 80; L 1.238634; Tr_acc 0.568514; Te_acc 0.515724; T 1.035585\n",
      "Tr_acc 0.590078; Te_acc 0.527591\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAIABJREFUeJzt3Xl8XXWd//HX5y7ZtyZpuiTdaWlL\nKVssmyP8BKSgUgcdBEcR5WedcRwdtxEfzsMFfzoqv8EZldFhHH4wzAwK6mBHi4zDJgiUpmylLaWh\nS5p0yb4vd/v+/vje0jSk5LZNSO/J+/l43Edyzj333M/JSd/32+/5nm/MOYeIiARLaLILEBGR8adw\nFxEJIIW7iEgAKdxFRAJI4S4iEkAKdxGRABoz3M3sDjNrNrOXjvK8mdn3zazezF40s7PHv0wRETkW\nmbTc7wRWv8HzVwCL04+1wI9OvCwRETkRY4a7c+73QPsbbLIG+FfnPQ2Umdms8SpQRESOXWQc9lEN\n7B223Jhet3/khma2Ft+6p7Cw8JylS5eOw9uLiEwdmzZtanXOTR9ru/EIdxtl3ahzGjjnbgduB6it\nrXV1dXXj8PYiIlOHme3JZLvxGC3TCMwZtlwD7BuH/YqIyHEaj3BfB1yfHjVzHtDlnHtdl4yIiLx5\nxuyWMbN7gIuBSjNrBL4KRAGccz8G1gNXAvVAP/CRiSpWREQyM2a4O+euG+N5B/zFuFUkIiInTHeo\niogEkMJdRCSAFO4iIgGkcBcRCSCFu4hIACncRUQCSOEuIhJACncRkQBSuIuIBJDCXUQkgBTuIiIB\npHAXEQkghbuISAAp3EVEAkjhLiISQAp3EZEAUriLiASQwl1EJIAU7iIiAaRwFxEJIIW7iEgAKdxF\nRAJI4S4iEkAKdxGRAFK4i4gEkMJdRCSAFO4iIgGkcBcRCSCFu4hIACncRUQCSOEuIhJACncRkQBS\nuIuIBJDCXUQkgBTuIiIBlFG4m9lqM9tuZvVmdtMoz881s0fM7Dkze9HMrhz/UkVEJFNjhruZhYHb\ngCuA5cB1ZrZ8xGZ/A9zrnDsLuBb4x/EuVEREMpdJy30VUO+c2+mciwE/BdaM2MYBJenvS4F941ei\niIgcq0zCvRrYO2y5Mb1uuK8BHzSzRmA98Jej7cjM1ppZnZnVtbS0HEe5IiKSiUzC3UZZ50YsXwfc\n6ZyrAa4E7jaz1+3bOXe7c67WOVc7ffr0Y69WREQykkm4NwJzhi3X8PpulxuBewGcc08BeUDleBQo\nIiLHLpNw3wgsNrMFZpaDv2C6bsQ2DcAlAGa2DB/u6ncREZkkY4a7cy4BfBJ4ENiGHxWzxcxuNrOr\n0pt9DviYmb0A3APc4Jwb2XUjIiJvkkgmGznn1uMvlA5f95Vh328FLhzf0kRE5HjpDlURkQBSuIuI\nBJDCXUQkgBTuIiIBpHAXEQkghbuISAAp3EVEAkjhLiISQAp3EZEAUriLiASQwl1EJIAU7iIiAaRw\nFxEJIIW7iEgAKdxFRAJI4S4iEkAKdxGRAFK4i4gEkMJdRCSAFO4iIgGkcBcRCSCFu4hIACncRUQC\nSOEuIhJACncRkQBSuIuIBJDCXUQkgBTuIiIBpHAXEQkghbuISAAp3EVEAkjhLiISQAp3EZEAUriL\niARQRuFuZqvNbLuZ1ZvZTUfZ5hoz22pmW8zsP8a3TBERORaRsTYwszBwG3AZ0AhsNLN1zrmtw7ZZ\nDHwJuNA512FmVRNVsIiIjC2TlvsqoN45t9M5FwN+CqwZsc3HgNuccx0Azrnm8S1TRESORSbhXg3s\nHbbcmF433BJgiZn9wcyeNrPVo+3IzNaaWZ2Z1bW0tBxfxSIiMqYxu2UAG2WdG2U/i4GLgRrgcTNb\n4ZzrPOJFzt0O3A5QW1s7ch8iIsfFOcdAPElHf5yu/jgAOREjJxymMDdMcV6UnIhvyw4lkvQOJugc\niNPeF6OtN0Y8mSI/GiYvGiYvGnrtayQUIpFyJFIpnINo2IiEQpjBQDzJQCxJfyxJZ3+czoEYvYMJ\nciIh8qNhciIhEknHUCJJLOkoyAlTkhelOC/CqTOLmVGSN6E/k0zCvRGYM2y5Btg3yjZPO+fiwC4z\n244P+43jUqWITBndg3Huf66JPW39NHb009IzRCQcIjcdzm29MZp7hmjvG8IBIfPtz2TqjduLedEQ\nqRTEkqmJPoQx/Z/3rOCD582b0PfIJNw3AovNbAHQBFwLfGDENvcD1wF3mlklvptm53gWKiLBF0+m\n+NhddWzY1U5+NEzNtHyqSnJJphx9QwmSDmaW5rGyppSKohzCZiSdI+WgJC/KtIIoZQVRAGJJRyyR\nom8oQfdAnO7BOOFQiKLcMIW5EaYV5FBe6B85kRCD8SSD8VT6a5LBRIpEMkUkHCIS8h8giZQjkfSt\n+PycMPnRMPk5YcoKokwryKEwN0I8kWIgnmQokSIn4j+UoqEQ/fEE3QMJugfjzJlWMOE/yzHD3TmX\nMLNPAg8CYeAO59wWM7sZqHPOrUs/9w4z2wokgS8459omsnARCZ5v/mYbG3a183d/cgZXn12N2Wi9\nwie5XJg2yupSoswqffPKyKTljnNuPbB+xLqvDPveAZ9NP0REjtl9dXu588nd3PjWBbz3nJrJLifr\nZRTuIiInIpVy7GztpW53B71DCRZVFbG4qoj8aJiXD/SwuamLW3/3ChcsquBLVyyd7HIDQeEuIics\nmXI0dQywp72PhvZ+9rYP0NIzRNdAjM7+OPUtvXSmR7EczfJZJfzwA2cTCWtWlPGgcBcREskU+7sG\n2d81iHOOSNgImeHwwwyTKb9NLJliKJHiQNdgOsT72dXax562/iNGoUTDxvSiXMoKcigriPKO5TOo\nnVfOOfOnMa0gh/rmXnY09zAQS7J0ZgnLZhVTUZQ7eT+AAFK4i0wR8WSK+uZetuzrZkdzDwfSYb6/\na4B9nYNjDiUcKS8aYs60AuZXFvL2ZVUsqChkfmUhc8sLmFGSRzh09IuhqxaUs2pB+YkekrwBhbtI\nFnLO0dgxwMbd7Wzc3UFb7xAl+VFK8qLkRkPEE76V3TuY4ED3IAe6BmnsHCCW8K3rnHCIGaW5zCzJ\n48w507jqjHzmlhcwqzSfcMhIphzJlMMMwiHfio+EjGgkRE44RFVJLtOLcrNzNMsUoXAXOUnFkyk2\nN3XxanMvTZ0DNHUMcKB7kIPpsO4eTABQnBdhdmk+PYNxugcTDCWS5IRDRCMhCnMizCzNY9nsEi5b\nPoPls0s4bXYJCyqL3rBlLdlP4S7yJmrtHXrtzsv9XYMARMMhomEjlXLEkikGYileaOzkmV3t9A4l\nXnttVXEus8rymV9RyHkLK1g0vYi3zC/n1JnFCmp5HYW7yAlwztHQ3s/mpi5aeoY4c04ZK6pLiYZD\n9A4lqNvdzsbd7bzU1M2Wfd209g5ltN8FlYWsOXM2Fyyq5LTZJcwqyyM3Ep7go5EgUbiLZMg5R/dg\ngvrmHjbsamfjrnaebeika+DIIX4FOWHmlhewo7mXZMoRCRmnVBVx0ZLpLJtVzILKQuaUFzC7LJ+w\nGbFkingyRdh8n3Y0bAryk00qCckYRPMnu5KMKdxF0oa3wjc3dbGvc5CugTjd6dkDm3sGGYwfHu53\nSlURV6yYycqaMlbWlFJZlMumPR1s2NXGrtY+Ll02g/MWVnDOvGnk5xw9rPNRkB/BOehsgN6D0NsM\nfc3Q1+of/W0Q64WhXkjFYek74awPQUE5JBPwygPw0i8gvxxmnwWzVkIiBh27oXM39LfDYDcMdfuw\ndg5cyu+zrxX6WsBCUDbXP0IRaH0FWndAKgEzT4d5F8Ccc2HGCihfAKH0+Rvsgp6Dft9D3f692nf6\n13fs8dtF8yFaALUfgVMundAfo/mZA958tbW1rq6ublLeW6aOZMpR39zL5qYuOvpi9MUSr03T2h9L\nMhBP0NoT42CPv0g5lB5NEg0bNdMK0iNQ/CRTM0pyqSrOY25FAbXzpgVzXHbPQSieMXnv39kA//Vp\nePXh1z+XVwoFFZBbDDlFEO+Hfc9BJB+WXA57N0DPfiisgsSgD9iRcksOvz6S64PcDKKFUFjpH6kk\ndO31tSRjULnEPyK50LABmur8/gEieVBS7T+EYj2jHJBB6Rwon+8/SOIDvu63fQFWXH1cPyIz2+Sc\nqx1rO7XcJTB2tfaxfvN+9ncNvDYt7Lb93fTHkkdslxMJUZgTpiAnQn5OmPLCHFbWlPGO5bksqCxi\nZU0pS2YUvzb/d+DseQq2/KcPmKLpfp1z8PA34PG/g3M+Ald8FyI5R74umYCBDv+I5vvWcrTAt0xf\nfQR2PgqJASicDgWVcOpqWHhxZjUlE/DsnfC7r/paLvkKzFwJRVWH9zeyHoADL8Ez/wRbfwU1q+Cd\nt8Lid/jQbt8JB17wwT1tvm+J54zDbIyJGDRvgYNboXkrdDVC8UwomQ3FsyG/zH+A5JVC2bzxec/j\noJa7ZI2hRJKt+7p5YW8nu1r7mFaYw8ySPEIh4/7nmnjyVT8R6aFpXCsKc1g2q4Qz5pRyenUZM0py\nyY+GJ+/29vad8PJvYMHbYNYZx7+ftlfhmdt9a7FiEVScAhWLoXIx5JUc/XWpFDxxKzzyLXBJ3+K8\n5m6oPht++yXY8COoPgeaNsHc8+Gaf/XdFZvuhBfv9a3ikUJR3z0CUL7IB35fq+9SiffDRTfBRV+E\nUAi698HD3/T7zyvx4ZdKQscu30pOJfyHwbu/D9Mmdq7zbJZpy13hLiedgViS+uZeXjnYQ31LLztb\netnZ0sfutj7iSf/7WpQbOWKYYHVZPtetmsM1tXOoGu+/cOOcf4Qy/FBorINn7/JhN/ssH2RP/wg2\n3+f7d8G3Ms/+EAx0QuNG3/o978/hnBtG32cqBbsfhw3/BNvXQzjqW4qdDYf3CVA8C+aeBxd8yof2\nIa318MAXfHfHivdC7Y1w/59BzwGYdyHsfATO+wRc/i3fZ/2rT0I4B4a6wMKwZLXvv84vh/xpvoXe\n3w4D7VC+EBb+ryMDOT4Av/4MvHCPf231OfDE93yAL3q7f36wEzDfb12+0P+slr7Ld5PIUSnc5aTW\n0Rdj24Futh/o4ZWDPTS099PWG6O1N0Zb3xCHfi2jYWNueQELpxexaHoRZ9SUcubcMmaV5hNLpGju\nGaRnMMGpM4oJjfdY72QcXvyZ76roa4VTr4TT3gOzzvR9sh27fbAuWe3/K+6cb1E/+GV/IS4xcHhf\n0QKo/SicfT3UPwQb/9m35MH/1z23GA6+BFfcAueuPfy65m3w/H/4wO1u8uH6lv8Nqz7muywOXSxs\n2wEt2/3jlQf8xb1Fl8D8C2HrOtj/PIRz4Yrv+A8QMx/Ov7jRB/4ffR7e/jeHg3Xf8/DYd/yxnv0h\n/0FyrJyDjT+B397kQ33Zu+Gyb/gwl+OmcJdJNxhPcrB7kKaOARra+2lo7+eVg71s3dfFvvQNPADT\nCqIsqCyksiiXiiJ/S/ySGUUsnlHM/IqCie1GScSg94AfhdHf7r/2tfgujy2/9C3jmSv9KImXf5Nu\nbY4QzvWjNlzS9/0uWQ1//GMfbvue832yS98FhRWHX5NKwYEXfddI0XRfx303wPbfwOV/6y/gPfUD\n348diviRFSuv8R8wYw3HG+z2ofrUbdDf6gP69D/xLfaSWUdum0r6kSBVEzjN7v4XID4Ic8+duPeY\nQhTuMqFiiRRtfUMkkn7+Eedgy75untnVTt2edva09b9u/HckZMyvLOS02SUsn1XCslklLJ1ZzPTi\nCZijZN9zvgUdzYeqZVB1mu+mqFjsu1f6Wn0XxzO3jx7YoYjvSvijz6Uv0JkP4F2P+ZZy2TzfDTHU\nCy/+FDb/3O/n7X8DF34m8y6c4RIx+MVHYdt/+eXiWXDux+Gs64/8YMhUrN9f/CytPvbXyklL4S7j\nqqVniPufa+LXm/fT2N5PW19s1O1yIyHOnFPGkhnFzCzNo6o4l+ppflKqmSV5o7fCew74i2wu5R/R\nAt8HWzbX9y0P9fhxwp17/EW5rkbfwj40NK6oyl+kLJvr9/fs3fCbz/mLe4WVvqsima43r8y3whvT\nw9mWvtMPoyuo9NsXVPogzSs7tr7fRMzXeTwhPFwyDo/f6j84Trt69BEiMqVpKKScEOccO1v7eOrV\nNh7d3sKj25tJpBxnzinj8hUzmVGcx/TiXKJh89cbcZxSVcSK6tIj76481HgYLSgTQ/D0P8Jjt0C8\n7/XPhyK+L3qgY8T6qA/1oW4/IuOQGadD2Rx/wXHhxfDeO3zYJhPQVu/HJ+/d4PuTT38vXPBpmL7k\nRH9UXiQHIicY7OA/zC7+4onvR6Y8hbsQTx7+4wsvNHbywt5OnmvopLnHz4MysySPG9+6gD+preGU\nquKxdzjQ4S8avvIg1P+PH1ERivobPgor/NC98oX+Ql5bve+PvuBTvgvFQn74Xdur0P6q31fZ3MPj\nlEvn+Nb1oW6PWL/vF9/xILy83u/zrZ/13SOH7hwMR3yfctVSOOuDE/NDFDnJqFtmCukbSrC7rY/6\n5l627utm6/5udhzs5WDPIMN/DeZVFHBGTRnnLizn/IUVLCiMY3mlhwO1q9H3V2/+uR8et2qtH97W\n2QBP/dB3iyQGfOv6lMt8F0NiyHeN9Bzwgd6+04/AuPxvYfE43obtnIbSSaCpW0Zo7R3i4ZebeWjb\nQZ5t6KSl5/CMhDnhEEtmFnHhKZVUT8tndmkeNdMKWFFdQllowN+0sudJ+MMGPwQvkg+Vp/jA3vU4\n4PxQu6ZN8G9X+1Z1V5Nvea98vx/yV1N7uPU80kSFsIJdBFC4B0ZzzyD/s7WZl/Z1sbe9n5b2Ttra\nW2lxZcwqzeNti6ezcHohSwv7WJLYzqy8BJF4A2Aw5y0wc5Efi1z3L/DYd31XSkm1nyBp1ko/NLB1\nh2+1n/8J31ovm+tb5Fv+048HX/pu/1xpzdgFK4RFJpTCPUulUo6t+7t5or6V/9l6kE0NHTgHZQVR\n3lW0gx8M3kpZbisD01eSd/p7sMIKfyPMoVb3SLklkFPobzFfcBFc9nV/x+BYIrlwxrX+ISInDYV7\nlmjtHeLFxk62NHWzuamLjbvb6ej348iXzSrhry5ZwuXLyjl16w+wP/y9v2i58uPkb38AHr7Z76R8\nEVz0137cdkE55Jb6vvGGp2H3E777Zc0PfXeLWtYiWU0XVE9isUSKh7Yd5N66vTz2SguH/jj9/IoC\nLqiJcPn0Ds7KO0BJd72fne7gS350yTk3+DlCcgr9C7oa/e3oVcsV2iJZThdUs0w8meKlpi427eng\nlYM97GjuZcfBXnqH4lxYdIC7Fm1neWQ/ZbH9hLsaYHszbE+/OFro78Jc9m4/rHDJ5UfuvLQms35w\nEQkMhfskcs7x8MvN3PWHnXTseYmVqa2stJ28LQqX5+dTUhnltKHnKehtgEbzQwrL5vnwrjjFB/r0\nU6F07vHd7i4igaVwnwSx3g5efeD7NO14nmmDe/lhaB8l4T4IQyq/glBOISSHoD/u50NZ9jk/YVRR\n1WSXLiJZQuH+JhlKJPlDfSvbnvotf7znZpbRSqVV4KYvpHDeBTBnFcw9n1D5QvWLi8gJU7hPoMF4\nkt9tPciDWw7w9PZGPpz8BZ+I/IqO6Gw2XnQfZ19wGeHxnoNcRASF+4RIphz3P9fEP/z3Vgq6X+Uj\neY9xS/hx8q2X5BkfpPLK71CZWzTZZYpIgCncx0sqhdv5KC2P/Zj8xj/w7tQA77Uk5IIL5WDL10Dt\nRwnPu2CyKxWRKUDhfqJSSXjuboYevZXcnj1EXBGPRc/n1MXzWTR7OqGiKh/sBeWTXamITCEK9xMw\ntPNJBtZ9nrLOLWxLLeK+8KdZesmf8v7zF5MT0dBEEZk8GYW7ma0G/gEIAz9xzn37KNu9D7gPeItz\nLpi3nzrH7rrf0v/4bSzvfpwON42v5nyWinOv46a3LqA4LzrZFYqIjB3uZhYGbgMuAxqBjWa2zjm3\ndcR2xcCngA0TUeikGejwf7mns4FdO7YQ2bGe+cm9dLhi/nv69ZRc+nm+umQuIY16EZGTSCYt91VA\nvXNuJ4CZ/RRYA2wdsd03gO8Cnx/XCifTnifhZx/yf0EemONCvBJexGPLv84Zqz/KO0pKJrlAEZHR\nZRLu1cDeYcuNwLnDNzCzs4A5zrlfm9lRw93M1gJrAebOnXvs1b6ZNt2J+83n6c2v5q/tz9gWm8E1\nl6ziYxctYflof+RZROQkkkm4j9bf8NpUkmYWAr4H3DDWjpxztwO3g58VMrMS32SpFDz4JdjwYzbn\n1fLBto+zaG41P3nfysz+fqiIyEkgk3BvBOYMW64B9g1bLgZWAI+av21+JrDOzK7KuouqyQSp+z9B\naPPPuDN1BbcOfJgvrFnGB86dpztJRSSrZBLuG4HFZrYAaAKuBT5w6EnnXBdQeWjZzB4FPp91wZ4Y\nouvfPkzp7ge4JX4NO5as5b/fczozS/MmuzIRkWM2Zrg75xJm9kngQfxQyDucc1vM7Gagzjm3bqKL\nnDBDPdBYR3zPM7Rsup/ZfVu5NfQRTr/ui3xhxazJrk5E5LhlNM7dObceWD9i3VeOsu3FJ17Wm2D/\nC3DXVTDYSRToS1Xzszlf5sY//QylBRqrLiLZbWreodpaT+ruq+lM5PJXsS/SWb6SL199Pu9fWDHZ\nlYmIjIupF+5dTcTuvIq+/hjXxL7GVZe8jY9ftJDcSHiyKxMRGTdTK9w7dtP1kzWEetv5TM7NfPej\nV3P23GmTXZWIyLibOuH+6sMM3HMDxOP8YOY3+d71H2JaYc5kVyUiMiGCH+7OwZPfJ/W7r7EnVc39\np36XL153hcati0igBT7c3RPfwx76OuuT5/LUaV/nG9ecp0m+RCTwgh3uL96HPfR17k9ewFMrv8Xf\nvvdMBbuITAnBDfddvyf5n3/GM8nlPHnazXxbwS4iU0gww73tVWL//gF2J2fwyyXf5tvX1CrYRWRK\nCV64J+P0/McNJOMp7ph3C9/6wB/p4qmITDmBm5jcPfptitte5JacT/C161cT1dzrIjIFBSv59jwF\nj9/KvYmLeMuVHyEvqrtORWRqCk64D3aT+uVamqyK+6r+gqvOmD3ZFYmITJrghPvD34CuRv5y8M/5\nzJXn6AKqiExpwQj3pmdxz/wz9/AOSpdcwAWnVI79GhGRAMv+0TKpJPz6M/TnVPDt7vdx7+qlk12R\niMiky/6We90dsP95/s4+zLL5NSybVTLZFYmITLrsDveeg/DQzXTOupA7us7m2lVzxn6NiMgUkN3h\n/sI9MNTNP+Z9nOK8KFfo756KiADZHu67HydZsYQ7d+TwnjOryc/RuHYREcjmcE/GoeFpdhScRSyR\nUpeMiMgw2Rvu+1+AWC+/aJvP6dWlnDa7dLIrEhE5aWRvuO/6PQC/bJ/P+9+iVruIyHDZG+67n6C9\ncBFtlHLl6bqQKiIyXHaGe7q/fXv+mZTmRynXH7oWETlCdob7vucg3sczbjnzKgomuxoRkZNOdoZ7\nur/9oYHFzC1XuIuIjJSd4b77CVzVcrZ2RtVyFxEZRfaFeyIGezfQO/N8EinHvPLCya5IROSkk33h\nvu9ZiPfTUHoOAHPVchcReZ3sC/fdjwPGS9EVAMyvUMtdRGSk7JvP/ZyPwqyzqH8lSm4kRFVx7mRX\nJCJy0sm+lnthBSy+lD1t/cwtL9Cf0xMRGUX2hXtaQ3u/RsqIiBxFRuFuZqvNbLuZ1ZvZTaM8/1kz\n22pmL5rZQ2Y2b/xLPcw5R0N7P3M1UkZEZFRjhruZhYHbgCuA5cB1ZrZ8xGbPAbXOuZXAz4Hvjneh\nw7X0DtEfS6rlLiJyFJm03FcB9c65nc65GPBTYM3wDZxzjzjn+tOLTwM141vmkfa0+bdSuIuIjC6T\ncK8G9g5bbkyvO5obgQdGe8LM1ppZnZnVtbS0ZF7lCIfDXd0yIiKjySTcRxuO4kbd0OyDQC1wy2jP\nO+dud87VOudqp0+fnnmVIzS09REyqC7LP+59iIgEWSbj3BuB4X8NowbYN3IjM7sU+DJwkXNuaHzK\nG92e9n5ml+WTE8nawT4iIhMqk3TcCCw2swVmlgNcC6wbvoGZnQX8E3CVc655/Ms80p42DYMUEXkj\nY4a7cy4BfBJ4ENgG3Ouc22JmN5vZVenNbgGKgPvM7HkzW3eU3Y0LDYMUEXljGU0/4JxbD6wfse4r\nw76/dJzrOqruwTjtfTG13EVE3kDWdVo3pEfKzFe4i4gcVdaF+6FhkOqWERE5uuwL9/Y+QPO4i4i8\nkayb8vd959RwZk0ZRblZV7qIyJsm6xKyqjiPquK8yS5DROSklnXdMiIiMjaFu4hIACncRUQCSOEu\nIhJACncRkQBSuIuIBJDCXUQkgBTuIiIBpHAXEQkghbuISAAp3EVEAkjhLiISQAp3EZEAUriLiASQ\nwl1EJIAU7iIiAaRwFxEJIIW7iEgAKdxFRAJI4S4iEkAKdxGRAFK4i4gEkMJdRCSAFO4iIgGkcBcR\nCSCFu4hIACncRUQCSOEuIhJACncRkQBSuIuIBFBG4W5mq81su5nVm9lNozyfa2Y/Sz+/wczmj3eh\nIiKSuTHD3czCwG3AFcBy4DozWz5isxuBDufcKcD3gO+Md6EiIpK5TFruq4B659xO51wM+CmwZsQ2\na4C70t//HLjEzGz8yhQRkWMRyWCbamDvsOVG4NyjbeOcS5hZF1ABtA7fyMzWAmvTi71mtv14igYq\nR+57ipiKxz0Vjxmm5nFPxWOGYz/ueZlslEm4j9YCd8exDc6524HbM3jPNy7IrM45V3ui+8k2U/G4\np+Ixw9Q87ql4zDBxx51Jt0wjMGfYcg2w72jbmFkEKAXax6NAERE5dpmE+0ZgsZktMLMc4Fpg3Yht\n1gEfTn//PuBh59zrWu4iIvLmGLNbJt2H/kngQSAM3OGc22JmNwN1zrl1wL8Ad5tZPb7Ffu1EFs04\ndO1kqal43FPxmGFqHvdUPGaYoOM2NbBFRIJHd6iKiASQwl1EJICyLtzHmgohCMxsjpk9YmbbzGyL\nmX06vb7czH5nZjvSX6dNdq3jzczCZvacmf06vbwgPaXFjvQUFzmTXeN4M7MyM/u5mb2cPufnT5Fz\n/Zn07/dLZnaPmeUF7Xyb2R2YELS6AAAC7UlEQVRm1mxmLw1bN+q5Ne/76Wx70czOPpH3zqpwz3Aq\nhCBIAJ9zzi0DzgP+In2cNwEPOecWAw+ll4Pm08C2YcvfAb6XPuYO/FQXQfMPwG+dc0uBM/DHH+hz\nbWbVwKeAWufcCvxgjWsJ3vm+E1g9Yt3Rzu0VwOL0Yy3woxN546wKdzKbCiHrOef2O+eeTX/fg//H\nXs2R0zzcBbxnciqcGGZWA7wT+El62YC346e0gGAecwnwNvyIM5xzMedcJwE/12kRID99b0wBsJ+A\nnW/n3O95/T0/Rzu3a4B/dd7TQJmZzTre9862cB9tKoTqSarlTZGeYfMsYAMwwzm3H/wHAFA1eZVN\niL8H/hpIpZcrgE7nXCK9HMTzvRBoAf5fujvqJ2ZWSMDPtXOuCfi/QAM+1LuATQT/fMPRz+245lu2\nhXtG0xwEhZkVAb8A/so51z3Z9UwkM3sX0Oyc2zR89SibBu18R4CzgR85584C+ghYF8xo0v3Ma4AF\nwGygEN8tMVLQzvcbGdff92wL90ymQggEM4vig/3fnXO/TK8+eOi/aemvzZNV3wS4ELjKzHbju9ve\njm/Jl6X/2w7BPN+NQKNzbkN6+ef4sA/yuQa4FNjlnGtxzsWBXwIXEPzzDUc/t+Oab9kW7plMhZD1\n0n3N/wJsc87dOuyp4dM8fBj41Ztd20Rxzn3JOVfjnJuPP68PO+f+FHgEP6UFBOyYAZxzB4C9ZnZq\netUlwFYCfK7TGoDzzKwg/ft+6LgDfb7TjnZu1wHXp0fNnAd0Heq+OS7Ouax6AFcCrwCvAl+e7Hom\n6Bjfiv/v2IvA8+nHlfg+6IeAHemv5ZNd6wQd/8XAr9PfLwSeAeqB+4Dcya5vAo73TKAufb7vB6ZN\nhXMNfB14GXgJuBvIDdr5Bu7BX1OI41vmNx7t3OK7ZW5LZ9tm/Eii435vTT8gIhJA2dYtIyIiGVC4\ni4gEkMJdRCSAFO4iIgGkcBcRCSCFu4hIACncRUQC6P8DLS69aPvfy5IAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f1dd80d9240>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lr = 0.01 / batch_size\n",
    "epochs = 100\n",
    "\n",
    "a, b = [], []\n",
    "for epoch in range(epochs):\n",
    "    if epoch > 80:\n",
    "        lr = 0.001 / batch_size\n",
    "    train_loss = 0.\n",
    "    train_acc = 0.\n",
    "    start = time()\n",
    "    for data, label in train_data:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            l = loss(output, label)\n",
    "        l.backward()\n",
    "        for param in params:\n",
    "            param[:] = param - lr * param.grad\n",
    "        train_loss += nd.mean(l).asscalar()\n",
    "        train_acc += accuracy(output, label)\n",
    "    test_acc = evaluate_accuracy(test_data, net, ctx)\n",
    "    \n",
    "    if epoch%20 == 0:\n",
    "        print(\"E %d; L %f; Tr_acc %f; Te_acc %f; T %f\" % (\n",
    "            epoch, train_loss / batch, train_acc / batch, test_acc, time() - start))\n",
    "    a.append(train_acc/len(train_data))\n",
    "    b.append(test_acc)\n",
    "    \n",
    "print(\"Tr_acc %f; Te_acc %f\" % (train_acc / batch, test_acc))\n",
    "plt.plot(np.arange(epochs), a, np.arange(epochs), b)\n",
    "plt.ylim(0,1)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
